from typing import Optional
import torch
import torch.nn.functional as F
from indexer_topk_reducesum import indexer_topk_reducesum_interface
from indexer_bwd import indexer_bwd_interface
from sparse_mla_fwd import sparse_mla_fwd_interface
from sparse_mla_bwd import sparse_mla_bwd
from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface
from einops import einsum, repeat
from utils import get_abs_err, get_err_ratio


class RegsiterLossFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, loss):
        ctx.save_for_backward(loss)
        return x

    @staticmethod
    def backward(ctx, grad):
        loss = ctx.saved_tensors
        return grad, torch.ones(1, dtype=loss[0].dtype, device=loss[0].device)


register_loss = RegsiterLossFunction.apply


def ref_deepseek_sparse_attention_innner(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
    dtype = q.dtype
    q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32),
                                           (q, kv, index_q, index_k, weights))

    index_sm_scale = index_q.shape[-1]**-0.5
    b, s = index_q.shape[:2]

    # tl_topk_indices = tl_topk_indices.to(torch.int64)
    # tl_topk_indices[tl_topk_indices == -1] = s

    casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
    index_logits = einsum(index_q, index_k, 'b s1 h k, b s2 k -> b s1 h s2')
    index_logits = F.relu(index_logits)
    index_logits = (index_logits * weights.unsqueeze(-1)).sum(
        dim=-2, dtype=torch.float32) * index_sm_scale
    index_logits = torch.where(casual_mask, index_logits, float('-inf'))
    topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices
    topk_logits = torch.gather(
        F.pad(index_logits, (0, 1), value=float('-inf')), dim=-1, index=topk_indices)
    topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32)
    index_topk_score = topk_score

    if sm_scale is None:
        sm_scale = kv.shape[-1]**-0.5

    h = q.shape[-2]
    index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda")\
        .scatter_(dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool))[:, :, :-1]
    mask = repeat(casual_mask & index_mask, 'b s1 s2 -> b s1 h s2', h=h)
    k, v = kv, kv[..., :dim_v]
    logits = einsum(q, k, 'b s1 h d, b s2 d -> b s1 h s2') * sm_scale
    logits = torch.where(mask, logits, float('-inf'))
    attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
    o = einsum(attn_score, v, 'b s1 h s2, b s2 d -> b s1 h d')

    attn_score = attn_score.sum(dim=-2)  # [b, s1, s2]
    attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices)
    attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True)

    loss = F.kl_div(
        index_topk_score.clip(-100, 0),
        attn_topk_score.detach().log().clip(-100, 0),
        log_target=True,
        reduction="sum")
    o = register_loss(o, loss)

    return o.to(dtype), topk_indices


def ref_deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
    all_o, all_topk_indices = [], []
    for i in range(offsets.shape[0] - 1):
        o, topk_indices = ref_deepseek_sparse_attention_innner(
            q[None, offsets[i]:offsets[i + 1]],
            kv[None, offsets[i]:offsets[i + 1]],
            index_q[None, offsets[i]:offsets[i + 1]],
            index_k[None, offsets[i]:offsets[i + 1]],
            weights[None, offsets[i]:offsets[i + 1]],
            topk,
            dim_v,
            sm_scale,
            index_sm_scale,
        )
        all_o.append(o.squeeze(0))
        all_topk_indices.append(topk_indices.squeeze(0))
    o = torch.cat(all_o, dim=0)
    topk_indices = torch.cat(all_topk_indices, dim=0)
    return o, topk_indices


class DSAFunction(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        kv: torch.Tensor,
        index_q: torch.Tensor,
        index_k: torch.Tensor,
        weights: torch.Tensor,
        offsets: torch.Tensor,
        topk: int,
        dim_v: int,
        sm_scale: Optional[float] = None,
    ):
        # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk)
        topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k,
                                                                     topk, offsets)
        o, lse = sparse_mla_fwd_interface(
            q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
        ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse,
                              offsets)
        ctx.topk = topk
        ctx.dim_v = dim_v
        ctx.sm_scale = sm_scale
        return o, topk_indices

    @staticmethod
    def backward(
        ctx,
        do: torch.Tensor,
        _1: torch.Tensor,
    ):
        q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors
        attn_score = sparse_mla_topk_reducesum_interface(
            q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets,
            dim_v=ctx.dim_v).squeeze(-2)
        dq, dkv = sparse_mla_bwd(
            q,
            kv.unsqueeze(-2),
            o,
            do,
            topk_indices.unsqueeze(-2),
            lse,
            offsets,
            sm_scale=ctx.sm_scale)
        dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score,
                                                             index_score, topk_indices, offsets)
        return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None


def deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
):
    return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale)


def test_kernel(
    B=1,
    S=2048,
    H=16,
    D=512,
    tail_D=64,
    index_D=128,
    topk=64,
):
    torch.manual_seed(42)
    q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_()
    kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_()
    index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_()
    weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_()
    index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_()
    do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_()
    offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda()

    o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
    o.backward(do)
    q_grad, q.grad = q.grad, None
    kv_grad, kv.grad = kv.grad, None
    index_q_grad, index_q.grad = index_q.grad, None
    index_k_grad, index_k.grad = index_k.grad, None
    weights_grad, weights.grad = weights.grad, None

    ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights,
                                                            offsets, topk, D)
    ref_o.backward(do)
    ref_q_grad, q.grad = q.grad, None
    ref_kv_grad, kv.grad = kv.grad, None
    ref_index_q_grad, index_q.grad = index_q.grad, None
    ref_index_k_grad, index_k.grad = index_k.grad, None
    ref_weights_grad, weights.grad = weights.grad, None

    print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}")
    print(
        f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}"
    )
    print(
        f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}"
    )
    print(
        f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}"
    )
    print(
        f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}"
    )
    print(
        f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}"
    )

    intersections = []
    for j in range(S):
        ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
        trt_np = topk_indices[j].cpu().to(torch.int32).numpy()

        mask = (trt_np != -1)

        set_ref = set(ref_np[mask])
        set_trt = set(trt_np[mask])
        intersection = set_ref & set_trt
        intersections.append(len(intersection) / len(set_ref))
    print("average intersections: {:.4f}".format(sum(intersections) / len(intersections)))


test_kernel()
